#!/usr/bin/env python3

# Copyright (c) 2024, United States Government, as represented by the
# Administrator of the National Aeronautics and Space Administration.
#
# All rights reserved.
#
# The Astrobee platform is licensed under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with the
# License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""
Analyze bad pixels:
- Analyze Bayer format images from the specified cameras (e.g., nav, dock) and bags and use the
  resulting statistics to configure instances of the specified types of bad pixel correctors
  (e.g., BiasCorrector, NeighborMeanCorrector). Each corrector is written as a *.json file that may also
  reference additional data files (e.g., BiasCorrector requires a *_bias.png file).
- Report on the performance of each generated corrector using another image reading pass through the
  bags. This allows plotting the error statistics between raw vs. corrected images, as well as
  displaying detail views of example raw vs. corrected images. The information is written to a PDF
  report.

Using the correctors: In the bag_processing package, script rosbag_debayer.py by default analyzes and
generates a new corrector from the bags it is processing every time it is run. If desired, you can
tell it to use a corrector pre-generated by analyze_bad_pixels instead, using its `--load-corrector`
option. This has the following advantages:
- It avoids requiring rosbag_debayer.py to do a slow initial pass through its input for bad pixel
  analysis.
- You may feel more comfortable using a bad pixel corrector after reviewing its performance report.
  (However, note that the optimal corrector configuration could change over time even for the same
  camera due to ongoing radiation damage or other factors, so it may be better to configure a new
  corrector for each activity.)

Caveats:
- The optimal corrector configuration is likely to vary across bags with different camera exposure
  settings, as when tuned to operate in darker modules like NOD2 or USL. (This hasn't been confirmed
  yet.)  You may want to take extra care to use separate bad pixel correctors for bags that have
  different exposure settings.
- Calibrating bad pixel correction with typical data from Astrobee ops requires enough images (say
  30+, not well tested yet), with enough camera motion so that local texture from the scene is not
  "burned in" to sections of the corrected images. Or if you happen to have uniformly dark frames
  with minimal texture, you may need fewer of those. If your bag doesn't have enough images, you may
  want to supplement it with earlier bags to have enough imagery to work with.
"""

import argparse
import itertools
import pathlib
import subprocess
import sys
import time
from typing import Iterable, List, Optional, Type, Union

from matplotlib import pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

from bag_processing import pixel_utils as pu

# Number of example frames report() aims to show as raw vs. corrected detail pages.
DESIRED_SAMPLE_IMAGES = 5
# Size of every report page, passed to matplotlib as the figure size in inches
# (width, height).
PAGE_DIMENSIONS = (11, 8.5)
# Font size applied to the report plots through plt.rc("font", ...).
FONT_SIZE = 9
# Corrector type names accepted by --correctors; also that option's default.
ALL_CORRECTORS = ["BiasCorrector", "NeighborMeanCorrector"]


def get_image_source(
    bags: List[str], topic: str, image_ratio: int
) -> Iterable[pu.BayerImage]:
    """
    Build the stream of images found on `topic` in `bags`.

    :param bags: Bags to read, handed unchanged to pu.ImageSourceBagPaths.
    :param topic: Image topic to read.
    :param image_ratio: Keep 1 out of every image_ratio images (1 keeps them all).
    :return: Iterable over the selected images.
    """
    source: Iterable[pu.BayerImage] = pu.ImageSourceBagPaths(bags, topic=topic)
    if image_ratio == 1:
        # No subsampling requested: hand back the source object itself.
        return source
    # Keep the first image and then every image_ratio-th image after it.
    return itertools.islice(source, 0, None, image_ratio)


def report(
    bags: List[str],
    cam: str,
    topic: str,
    output_dir: pathlib.Path,
    image_ratio: int,
    stats: pu.DebugStats,
    corrector: pu.BadPixelCorrector,
) -> None:
    """
    Write PDF report about `corrector` results applied to `cam` images in `bags`.

    The first page holds two side-by-side plots drawn by pu.plot_mean_vs_rms_err(): one
    for the raw image statistics (`stats`) and one for statistics gathered here by a
    second pass through the bags with `corrector` applied. The following pages are
    raw vs. corrected examples drawn by pu.plot_image_correction_example() for at most
    DESIRED_SAMPLE_IMAGES frames spread evenly through the image sequence.

    :param bags: Bags to read images from.
    :param cam: Camera name; used in progress messages and in the report file name.
    :param topic: Topic the images are read from.
    :param output_dir: Directory the report PDF is written into.
    :param image_ratio: Process 1 out of every image_ratio images.
    :param stats: Statistics previously gathered from the raw images.
    :param corrector: Corrector whose output is evaluated.
    """
    levels = pu.get_levels(stats)
    corrector_name = corrector.__class__.__name__

    def all_images():
        # A fresh image source is needed for every pass through the bags.
        return get_image_source(bags, topic, image_ratio)

    t0 = time.time()
    print(
        f"Analyzing {cam} images corrected by {corrector_name}... ",
        end="",
    )
    sys.stdout.flush()
    stats2 = pu.DebugStatsAccumulator.get_image_stats_parallel(
        all_images(), preprocess=corrector
    )
    assert stats2 is not None  # should have enough data if we got to this point
    t1 = time.time()
    print(f"done in {t1 - t0:.1f}s")

    plt.rc("font", size=FONT_SIZE)

    report_path = output_dir / f"{corrector_name}_{cam}_cam_report.pdf"
    with PdfPages(report_path) as pdf:
        # Page 1: statistics before and after correction, side by side.
        data_to_plot = {
            "Raw images": stats,
            "Corrected images": stats2,
        }
        fig, axes = plt.subplots(1, len(data_to_plot), figsize=PAGE_DIMENSIONS)
        for i, (title, stats_i) in enumerate(data_to_plot.items()):
            pu.plot_mean_vs_rms_err(axes[i], stats_i, title)
        page_title = f"bags={str(bags)[:80]}\ntopic='{topic}'"
        if image_ratio != 1:
            page_title += f" 1 out of every {image_ratio} images"
        fig.suptitle(page_title)
        fig.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

        # Remaining pages: example frames. Pick exactly min(DESIRED_SAMPLE_IMAGES, count)
        # evenly spread frame indices. Stepping by `count // num_samples` instead could
        # emit almost twice the desired number of pages (e.g. 9 frames -> step 1 -> 9
        # pages) and would divide by zero for an empty sequence.
        # NOTE(review): assumes stats.count equals the number of images all_images()
        # yields -- confirm against pu.DebugStats.
        num_sample_images = min(DESIRED_SAMPLE_IMAGES, stats.count)
        sample_indices = {
            (k * stats.count) // num_sample_images for k in range(num_sample_images)
        }
        for i, image in enumerate(all_images()):
            if i not in sample_indices:
                continue
            pu.plot_image_correction_example(image, corrector, levels)
            # The plotting helper draws on the current figure rather than returning it.
            fig = plt.gcf()
            fig.set_size_inches(PAGE_DIMENSIONS)
            fig.suptitle(
                f"Frame {i + 1} / {stats.count}: detail of image center (15% x 15%)"
            )
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)
    print(f"Wrote to {report_path}")


def analyze_bad_pixels(
    bags: List[str],
    cameras: List[str],
    correctors: List[str],
    topic_template: str,
    output_dir: pathlib.Path,
    bias_stats_only: bool,
    image_ratio: int,
    no_report: bool,
) -> None:
    """
    Driver for analysis code.

    For each camera: gather statistics from its images in `bags`, configure one corrector
    of each requested type from those statistics, save each corrector as
    `<CorrectorClass>_<cam>_cam.json` in `output_dir`, and (unless disabled) write a PDF
    report for it. Cameras for which get_image_stats_parallel() returns None are skipped
    with a message.

    :param bags: Bags to read images from.
    :param cameras: Camera names; each is substituted for '{cam}' in `topic_template`
        and processed separately.
    :param correctors: Names of the corrector types to generate.
    :param topic_template: Image topic, with '{cam}' standing for the camera name.
    :param output_dir: Directory to write results to; created (with parents) if missing.
    :param bias_stats_only: Gather statistics with pu.BiasStatsAccumulator instead of
        pu.DebugStatsAccumulator. Reports are only written for pu.DebugStats results.
    :param image_ratio: Process 1 out of every image_ratio images.
    :param no_report: Skip writing the PDF reports.
    """
    # Same effect as `mkdir -p`, but without spawning an external process (and thus
    # also working on platforms that have no `mkdir` executable).
    output_dir.mkdir(parents=True, exist_ok=True)

    if image_ratio != 1:
        print(f"Processing 1 out of every {image_ratio} images")

    for cam in cameras:
        accum_class: Union[
            Type[pu.BiasStatsAccumulator], Type[pu.DebugStatsAccumulator]
        ]
        if bias_stats_only:
            accum_class = pu.BiasStatsAccumulator
        else:
            accum_class = pu.DebugStatsAccumulator

        t0 = time.time()
        print(
            f"Analyzing {cam} images raw... ",
            end="",
        )
        sys.stdout.flush()

        topic = topic_template.format(cam=cam)
        images = get_image_source(bags, topic, image_ratio)
        stats = accum_class.get_image_stats_parallel(images)
        t1 = time.time()
        print(f"done in {t1 - t0:.1f}s")

        if stats is None:
            print(f"Not enough {cam} images were found, not generating correctors")
            continue

        # Passed to each corrector's from_image_stats() so it can carry a description
        # of the data it was configured from.
        note = f"bags={bags} topic='{topic}'"

        for corrector_name in correctors:
            corrector_class = pu.BadPixelCorrector.get_classes(corrector_name)[0]
            corrector = corrector_class.from_image_stats(stats, note)
            corrector.save(output_dir / f"{corrector_class.__name__}_{cam}_cam.json")

            # The report is only produced for pu.DebugStats results, so it is skipped
            # for the statistics gathered when bias_stats_only is set.
            if not no_report and isinstance(stats, pu.DebugStats):
                report(
                    bags=bags,
                    cam=cam,
                    topic=topic,
                    output_dir=output_dir,
                    image_ratio=image_ratio,
                    stats=stats,
                    corrector=corrector,
                )


class CustomFormatter(
    argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter
):
    """
    Help formatter combining two stock argparse formatters.

    Argument help gets the default value appended (ArgumentDefaultsHelpFormatter) and
    the parser description is printed with its line breaks kept as written
    (RawDescriptionHelpFormatter).
    """


def main(argv: Optional[List[str]] = None) -> None:
    """
    Parse args and call analyze_bad_pixels().

    :param argv: Full argument vector, program name first; defaults to sys.argv.
    :raises SystemExit: Raised by argparse on invalid arguments, a non-positive
        --image-ratio, or an output directory that already exists.
    """
    if argv is None:
        argv = sys.argv
    parser = argparse.ArgumentParser(
        prog=pathlib.Path(argv[0]).name,
        description=__doc__,
        formatter_class=CustomFormatter,
    )
    parser.add_argument(
        "inbag",
        nargs="+",
        help="bags to read images from",
    )
    parser.add_argument(
        "-c",
        "--cameras",
        nargs="+",
        choices=["nav", "dock"],
        default=["nav"],
        help="cameras to fill in for '{cam}' in --topic; each camera processed separately",
    )
    parser.add_argument(
        "--correctors",
        nargs="+",
        choices=ALL_CORRECTORS,
        default=ALL_CORRECTORS,
        help="types of correctors to generate and analyze",
    )
    parser.add_argument(
        "-t",
        "--topic",
        help="process bayer images on this topic; '{cam}' filled based on --cameras",
        default="/hw/cam_{cam}_bayer",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        default="analyze_bad_pixels_output",
        help="output directory",
    )
    parser.add_argument(
        "--bias-stats-only",
        action="store_true",
        default=False,
        help="calculate BiasStats and BiasCorrector only; modest speedup, disables report",
    )
    parser.add_argument(
        "-r",
        "--image-ratio",
        type=int,
        default=1,
        help="process 1 out of every IMAGE_RATIO images for faster testing",
    )
    parser.add_argument(
        "--no-report",
        action="store_true",
        default=False,
        help="disable report and analysis of corrected images",
    )
    args = parser.parse_args(argv[1:])

    # The ratio ends up as an itertools.islice() step, which must be a positive
    # integer. Reject bad values here with a usage error instead of letting them
    # surface later as a ValueError traceback.
    if args.image_ratio < 1:
        parser.error(
            f"--image-ratio must be a positive integer, got {args.image_ratio}"
        )

    # Refuse to write into an existing directory so earlier results are never
    # overwritten.
    output_dir = pathlib.Path(args.output_dir)
    if output_dir.exists():
        parser.error(f"{output_dir} exists; not overwriting")

    if args.bias_stats_only:
        # --bias-stats-only restricts generation to BiasCorrector, overriding any
        # --correctors selection.
        correctors = ["BiasCorrector"]
        if correctors != args.correctors:
            print(
                f"--bias-stats-only was specified, forcing correctors to: {correctors}"
            )
    else:
        correctors = args.correctors

    analyze_bad_pixels(
        bags=args.inbag,
        cameras=args.cameras,
        correctors=correctors,
        topic_template=args.topic,
        output_dir=output_dir,
        bias_stats_only=args.bias_stats_only,
        image_ratio=args.image_ratio,
        no_report=args.no_report,
    )
